from enum import Enum
from typing import Any, Dict, List, Optional, Union

import numpy as np
from pydantic import BaseModel

from langchain.embeddings.base import Embeddings
from langchain.schema import BaseRetriever, Document
from langchain.vectorstores.utils import maximal_marginal_relevance


class SearchType(str, Enum):
    """Search strategies selectable on a retriever.

    Members are also ``str`` instances, so they compare equal to their
    plain string values (``"similarity"`` / ``"mmr"``).
    """

    # Plain nearest-neighbour search on the query embedding.
    similarity = "similarity"
    # Maximal marginal relevance search.
    mmr = "mmr"


class DocArrayRetriever(BaseRetriever, BaseModel):
    """
    Retriever class for DocArray Document Indices.

    Currently, supports 5 backends:
    InMemoryExactNNIndex, HnswDocumentIndex, QdrantDocumentIndex,
    ElasticDocIndex, and WeaviateDocumentIndex.

    Attributes:
        index: One of the above-mentioned index instances
        embeddings: Embedding model to represent text as vectors
        search_field: Field to consider for searching in the documents.
            Should be an embedding/vector/tensor.
        content_field: Field that represents the main content in your document schema.
            Will be used as a `page_content`. Everything else will go into `metadata`.
        search_type: Type of search to perform (similarity / mmr)
        top_k: Number of documents to return
        filters: Filters applied for document retrieval. The value is handed
            to the index's query builder unchanged, so its format is
            backend-specific.
    """

    index: Any
    embeddings: Embeddings
    search_field: str
    content_field: str
    search_type: SearchType = SearchType.similarity
    top_k: int = 1
    filters: Optional[Any] = None

    class Config:
        """Configuration for this pydantic object."""

        arbitrary_types_allowed = True

    def get_relevant_documents(self, query: str) -> List[Document]:
        """Get documents relevant for a query.

        The query is embedded with ``self.embeddings`` and dispatched to the
        search strategy selected by ``self.search_type``.

        Args:
            query: string to find relevant documents for

        Returns:
            List of relevant documents

        Raises:
            ValueError: If ``self.search_type`` is neither similarity nor mmr
        """
        query_emb = np.array(self.embeddings.embed_query(query))

        if self.search_type == SearchType.similarity:
            return self._similarity_search(query_emb)
        if self.search_type == SearchType.mmr:
            return self._mmr_search(query_emb)

        raise ValueError(
            f"Search type {self.search_type} does not exist. "
            f"Choose either 'similarity' or 'mmr'."
        )

    def _search(
        self, query_emb: np.ndarray, top_k: int
    ) -> List[Union[Dict[str, Any], Any]]:
        """
        Perform a search using the query embedding and return top_k documents.

        Without ``self.filters`` this is a plain ``index.find`` call. With
        filters, a combined vector + filter query is built through the
        index's query builder and executed.

        Args:
            query_emb: Query represented as an embedding
            top_k: Number of documents to return

        Returns:
            A list of top_k documents matching the query
        """

        # Imported lazily so the module can be imported without docarray.
        from docarray.index import ElasticDocIndex, WeaviateDocumentIndex

        # Each backend names the filter keyword of its query builder
        # differently; pick the one matching the index type.
        search_field = self.search_field
        if isinstance(self.index, WeaviateDocumentIndex):
            filter_key = "where_filter"
            # The Weaviate backend is queried with an empty search field
            # (presumably it has a single vector per document - TODO confirm
            # against the docarray documentation).
            search_field = ""
        elif isinstance(self.index, ElasticDocIndex):
            filter_key = "query"
        else:
            filter_key = "filter_query"

        if not self.filters:
            # No filter configured: a plain vector search is sufficient.
            return self.index.find(
                query=query_emb, search_field=search_field, limit=top_k
            ).documents

        query = (
            self.index.build_query()  # get empty query object
            .find(
                query=query_emb, search_field=search_field
            )  # add vector similarity search
            .filter(**{filter_key: self.filters})  # add filter search
            .build(limit=top_k)  # build the query
        )
        # execute the combined query and return the results
        docs = self.index.execute_query(query)
        # Some backends wrap the hits in a result object, others return
        # the documents directly.
        if hasattr(docs, "documents"):
            docs = docs.documents
        # Enforce the limit here as well, in case the executed query
        # returned more than top_k hits.
        return docs[:top_k]

    def _similarity_search(self, query_emb: np.ndarray) -> List[Document]:
        """
        Perform a similarity search.

        Args:
            query_emb: Query represented as an embedding

        Returns:
            A list of documents most similar to the query
        """
        docs = self._search(query_emb=query_emb, top_k=self.top_k)
        return [self._docarray_to_langchain_doc(doc) for doc in docs]

    def _mmr_search(self, query_emb: np.ndarray) -> List[Document]:
        """
        Perform a maximal marginal relevance (mmr) search.

        A pool of candidate documents is fetched first and then re-ranked
        with ``maximal_marginal_relevance``, keeping ``self.top_k`` of them.

        Args:
            query_emb: Query represented as an embedding

        Returns:
            A list of diverse documents related to the query
        """
        # The candidate pool used to be fixed at 20, which silently capped
        # the result at 20 documents for larger top_k values. Keep 20 as the
        # minimum pool size but never fetch fewer candidates than requested.
        fetch_k = max(20, self.top_k)
        docs = self._search(query_emb=query_emb, top_k=fetch_k)

        # Documents may come back as dicts or as DocArray objects.
        candidate_embs = [
            doc[self.search_field]
            if isinstance(doc, dict)
            else getattr(doc, self.search_field)
            for doc in docs
        ]
        mmr_selected = maximal_marginal_relevance(
            query_emb,
            candidate_embs,
            k=self.top_k,
        )
        return [self._docarray_to_langchain_doc(docs[idx]) for idx in mmr_selected]

    def _docarray_to_langchain_doc(self, doc: Union[Dict[str, Any], Any]) -> Document:
        """
        Convert a DocArray document (which also might be a dict)
        to a langchain document format.

        DocArray document can contain arbitrary fields, so the mapping is done
        in the following way:

        page_content <-> content_field
        metadata <-> all other fields whose value is a str, int, float or
            bool (anything else, e.g. tensors and embeddings, is skipped)

        Args:
            doc: DocArray document

        Returns:
            Document in langchain format

        Raises:
            ValueError: If the document doesn't contain the content field
        """
        is_dict = isinstance(doc, dict)
        fields = doc.keys() if is_dict else doc.__fields__

        if self.content_field not in fields:
            raise ValueError(
                f"Document does not contain the content field - {self.content_field}."
            )

        def read(name: str) -> Any:
            # Uniform access for dict-style and attribute-style documents.
            return doc[name] if is_dict else getattr(doc, name)

        metadata = {
            name: value
            for name in fields
            if name != self.content_field
            and isinstance((value := read(name)), (str, int, float, bool))
        }
        return Document(page_content=read(self.content_field), metadata=metadata)

    async def aget_relevant_documents(self, query: str) -> List[Document]:
        """Asynchronous retrieval is not supported by this retriever.

        Args:
            query: string to find relevant documents for

        Raises:
            NotImplementedError: Always
        """
        raise NotImplementedError
